#!/usr/bin/env python3
from __future__ import annotations

from collections import defaultdict
from typing import Callable, List, Optional, Sequence, Tuple

import numpy as np
import torch
import torch.distributions as torchd
from rpi.helpers import to_torch
from rpi.helpers.data import yield_batch_infinitely
from torch import nn
from torch.nn import functional as F
import wandb
import sys
import numpy


from .base import Agent
# from .ppo import PPOAgent


class Ensemble(nn.Module):
    """Ensemble of networks, each predicting a diagonal Gaussian.

    Each member is built by ``make_nn()``. Its output is split column-wise at
    ``input_dim`` into means and raw (pre-activation) variances; ``var_func`` maps the
    raw values to positive variances.

    Args:
        make_nn: Zero-argument factory returning one member network.
        num_nns: Number of ensemble members.
        input_dim: Column at which each member's output is split into means / raw variances.
        out_dim: Stored as ``self.out_dim``; not read by this class.
        obs_normalizer: Callable ``(obs, update=...)`` applied to inputs when
            ``normalize_input=True``; identity when omitted.
        std_from_means: If True, ``forward_stats`` measures the spread only from the
            disagreement between member means (predicted variances are ignored).
        var_func: Maps raw network outputs to variances (``F.softplus`` by default).
    """

    def __init__(self, make_nn: Callable, num_nns: int,
                 input_dim: int, out_dim: int,
                 obs_normalizer: Optional[nn.Module] = None,
                 std_from_means: bool = False,
                 var_func: Callable = F.softplus) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.var_func = var_func
        self.num_nns = num_nns
        self.nns = nn.ModuleList([make_nn() for _ in range(num_nns)])

        if obs_normalizer is not None:
            self.obs_normalizer = obs_normalizer
        else:
            # Identity normalizer with the same call signature.
            self.obs_normalizer = lambda obs, update: obs

        self.std_from_means = std_from_means

    def forward_all(self, inp, normalize_input: bool = False) -> List[torchd.Distribution]:
        """Return one predictive distribution per ensemble member."""
        return [
            self.forward_single(inp, nn_idx, normalize_input=normalize_input)
            for nn_idx in range(self.num_nns)
        ]

    def forward_single(self, inp: torch.Tensor, nn_idx: int, normalize_input: bool = False):
        """Return member ``nn_idx``'s prediction as an ``Independent(Normal, 1)`` distribution.

        The normalizer (if enabled) is called with ``update=False``, so its statistics
        are not changed here.
        """
        if normalize_input:
            normalizer = self.obs_normalizer
        else:
            normalizer = lambda obs, update: obs  # pass-through

        net = self.nns[nn_idx]
        mean_and_var = net(normalizer(inp, update=False))
        # NOTE(review): the output is split at ``input_dim`` although ``out_dim`` is also
        # stored; this is only right if the network emits ``input_dim`` means — confirm.
        means = mean_and_var[:, :self.input_dim]
        pre_vars = mean_and_var[:, self.input_dim:]

        # Small floor keeps the scale strictly positive for torchd.Normal.
        variances = self.var_func(pre_vars) + 1e-8
        return torchd.Independent(
            torchd.Normal(loc=means, scale=torch.sqrt(variances)), 1
        )

    @staticmethod
    def compute_stats(distributions: List[torchd.Distribution], name='stats', std_from_means: bool = False):
        """Summarize member distributions as a uniform mixture.

        Mixture variance per Lakshminarayanan et al. (https://arxiv.org/pdf/1612.01474.pdf,
        right before Section 3): ``E_m[sigma_m^2 + mu_m^2] - mu_*^2``. It is evaluated
        here in the algebraically identical form ``mean_m(sigma_m^2) + mean_m((mu_m - mu_*)^2)``.
        The direct subtraction suffers catastrophic cancellation when the means are large
        relative to the spread: it can drop the spread entirely or go negative, which
        turns ``sqrt`` into NaN. Both terms here are non-negative.

        Args:
            distributions: Member predictive distributions (need ``.mean``/``.variance``).
            name: Name of the returned namedtuple type.
            std_from_means: Use only the (unbiased) variance of member means; this is
                NaN with a single member.

        Returns:
            namedtuple with fields ``mean std var upper lower all_means all_vars
            max_gap upper_max_gap lower_max_gap``. ``upper``/``lower`` are ``mean +/- std``,
            and ``max_gap`` is the max-minus-min of the member means.
        """
        from collections import namedtuple
        stats = namedtuple(name, 'mean std var upper lower all_means all_vars max_gap upper_max_gap lower_max_gap')

        all_means = torch.stack([distr.mean for distr in distributions], dim=0)
        mean = all_means.mean(0)

        if std_from_means:
            var = torch.var(all_means, dim=0)
            all_vars = torch.stack([var for _ in distributions], dim=0)
        else:
            all_vars = torch.stack([distr.variance for distr in distributions], dim=0)
            var = all_vars.mean(0) + ((all_means - mean) ** 2).mean(0)

        std = torch.sqrt(var)
        upper = mean + std
        lower = mean - std

        max_gap = torch.max(all_means, dim=0)[0] - torch.min(all_means, dim=0)[0]
        upper_max_gap = mean + max_gap
        lower_max_gap = mean - max_gap

        return stats(mean=mean, std=std, var=var, upper=upper, lower=lower,
                     all_means=all_means, all_vars=all_vars,
                     max_gap=max_gap, upper_max_gap=upper_max_gap, lower_max_gap=lower_max_gap)

    def forward(self, inp, normalize_input: bool = False):
        """Return the average of the member means."""
        distributions = self.forward_all(inp, normalize_input=normalize_input)
        stacked = torch.stack([distr.mean for distr in distributions], dim=0)
        return stacked.mean(0)

    def forward_stats(self, inp, normalize_input: bool = False):
        """Return ``compute_stats`` over all members, honoring ``self.std_from_means``."""
        values = self.forward_all(inp, normalize_input=normalize_input)
        return self.compute_stats(values, std_from_means=self.std_from_means)

class StatePredictorEnsemble(nn.Module):
    """Ensemble of state predictors, each outputting a diagonal Gaussian.

    Each member is fed ``concat(normalized_obs, action)`` along dim 1. Its output is
    split at ``state_dim`` into means and raw variances; ``var_func`` maps the raw
    values to positive variances.

    Args:
        make_nn: Zero-argument factory returning one member network.
        num_state_nns: Number of ensemble members.
        state_dim: Number of mean columns in each member's output.
        act_dim: Stored as ``self.act_dim``; not read by this class.
        obs_normalizer: Callable ``(obs, update=...)`` applied to ``obs`` when
            ``normalize_input=True``; identity when omitted.
        var_func: Maps raw network outputs to variances (``F.softplus`` by default).
        beta (float): Adjust the UCB bound stddev (i.e., range is `mean + beta * std`).
        std_from_means: If True, spread is measured only from member-mean disagreement.
    """

    def __init__(self, make_nn: Callable, num_state_nns: int, state_dim: int, act_dim: int,
                 obs_normalizer: Optional[nn.Module] = None, var_func=F.softplus, beta=1.0,
                 std_from_means=False) -> None:
        super().__init__()
        self.state_dim = state_dim
        self.act_dim = act_dim
        self.var_func = var_func
        self.nns = nn.ModuleList([make_nn() for _ in range(num_state_nns)])

        if obs_normalizer is not None:
            self.obs_normalizer = obs_normalizer
        else:
            # Identity normalizer with the same call signature.
            self.obs_normalizer = lambda obs, update: obs

        self.beta = beta
        self.num_state_nns = num_state_nns

        self.std_from_means = std_from_means

    @staticmethod
    def compute_stats(distributions: List[torchd.Distribution], beta=1.0, name='state_stats', std_from_means: bool = False):
        """Summarize member distributions as a uniform mixture.

        Mixture variance per Lakshminarayanan et al. (https://arxiv.org/pdf/1612.01474.pdf,
        right before Section 3), evaluated as ``mean_m(sigma_m^2) + mean_m((mu_m - mu_*)^2)``.
        This is algebraically identical to ``E_m[sigma_m^2 + mu_m^2] - mu_*^2`` but avoids
        the float cancellation (and possible negative result / NaN std) of that form.

        Args:
            distributions: Member predictive distributions (need ``.mean``/``.variance``).
            beta: Width multiplier for ``upper``/``lower`` (``mean +/- beta * std``).
            name: Name of the returned namedtuple type.
            std_from_means: Use only the (unbiased) variance of member means.

        Returns:
            namedtuple with fields ``mean std var upper lower all_means all_vars
            max_gap upper_max_gap lower_max_gap``.
        """
        from collections import namedtuple
        stats = namedtuple(name, 'mean std var upper lower all_means all_vars max_gap upper_max_gap lower_max_gap')

        all_means = torch.stack([distr.mean for distr in distributions], dim=0)
        mean = all_means.mean(0)

        if std_from_means:
            var = torch.var(all_means, dim=0)
            all_vars = torch.stack([var for _ in distributions], dim=0)
        else:
            all_vars = torch.stack([distr.variance for distr in distributions], dim=0)
            var = all_vars.mean(0) + ((all_means - mean) ** 2).mean(0)
        std = torch.sqrt(var)

        upper = mean + beta * std
        lower = mean - beta * std

        max_gap = torch.max(all_means, dim=0)[0] - torch.min(all_means, dim=0)[0]
        upper_max_gap = mean + max_gap
        lower_max_gap = mean - max_gap

        return stats(mean=mean, std=std, var=var, upper=upper, lower=lower,
                     all_means=all_means, all_vars=all_vars,
                     max_gap=max_gap, upper_max_gap=upper_max_gap, lower_max_gap=lower_max_gap)

    def forward_all(self, obs, action, normalize_input: bool = False) -> List[torchd.Distribution]:
        """Return one predictive distribution per ensemble member."""
        return [
            self.forward_single(obs, action, nn_idx, normalize_input=normalize_input)
            for nn_idx in range(self.num_state_nns)
        ]

    def forward_single(self, obs, action, state_nn_idx, normalize_input=False):
        """Return member ``state_nn_idx``'s prediction as ``Independent(Normal, 1)``.

        ``obs`` must be 2-D (asserted). The normalizer, if enabled, is called with
        ``update=False``; ``action`` is concatenated un-normalized.
        """
        assert len(obs.shape) == 2, f'obs has invalid shape: {obs.shape}'

        if normalize_input:
            normalizer = self.obs_normalizer
        else:
            normalizer = lambda obs, update: obs  # pass-through

        state_nn = self.nns[state_nn_idx]
        mean_and_var = state_nn(torch.concat((normalizer(obs, update=False), action), dim=1))
        means = mean_and_var[:, :self.state_dim]
        pre_vars = mean_and_var[:, self.state_dim:]

        # Small floor keeps the scale strictly positive for torchd.Normal.
        variances = self.var_func(pre_vars) + 1e-8
        return torchd.Independent(
            torchd.Normal(loc=means, scale=torch.sqrt(variances)), 1
        )

    def forward(self, obs, action, normalize_input: bool = False):
        """Return the average of the member means."""
        distributions = self.forward_all(obs, action, normalize_input=normalize_input)
        stacked = torch.stack([distr.mean for distr in distributions], dim=0)
        return stacked.mean(0)

    def forward_stats(self, obs, action, normalize_input: bool = False):
        """Return ``compute_stats`` over all members using ``self.beta``/``self.std_from_means``."""
        values = self.forward_all(obs, action, normalize_input=normalize_input)
        return self.compute_stats(values, beta=self.beta, std_from_means=self.std_from_means)


class ValueEnsemble(nn.Module):
    """Ensemble of value networks, each predicting a 1-D Gaussian per input row.

    Each member's output is halved along dim 1 (``torch.chunk(..., 2, dim=1)``) into a
    mean and a raw variance; ``var_func`` maps the raw value to a positive variance.

    Args:
        make_vfn: Zero-argument factory returning one member network.
        num_value_nns: Number of ensemble members.
        obs_normalizers: Optional list with one ``(obs, update=...)`` callable per member,
            used when ``normalize_input=True``. NOTE(review): stored as a plain list, so any
            ``nn.Module`` normalizers are not registered as submodules here.
        var_func: Maps raw network outputs to variances (``F.softplus`` by default).
        beta (float): Adjust the UCB bound stddev (i.e., range is `mean + beta * std`).
        std_from_means: If True, spread is measured only from member-mean disagreement.
    """

    def __init__(self, make_vfn: Callable, num_value_nns=5, obs_normalizers=None, var_func=F.softplus, beta=1.0, std_from_means=False) -> None:
        super().__init__()
        self.var_func = var_func
        self.vfns = nn.ModuleList([make_vfn() for _ in range(num_value_nns)])

        if obs_normalizers is not None:
            assert len(obs_normalizers) == len(self.vfns)
            self.obs_normalizers = obs_normalizers
        else:
            self.obs_normalizers = [lambda obs, update: obs for _ in self.vfns]

        self.beta = beta
        self.num_value_nns = num_value_nns

        self.stddev_coef = 1.0
        self.std_from_means = std_from_means

    @staticmethod
    def compute_stats(distributions: List[torchd.Distribution], beta=1.0, name='value_stats', stddev_coef: float = 1.0, std_from_means: bool = False):
        """Summarize member distributions as a uniform mixture.

        Mixture variance per Lakshminarayanan et al. (https://arxiv.org/pdf/1612.01474.pdf,
        right before Section 3), evaluated as ``mean_m(sigma_m^2) + mean_m((mu_m - mu_*)^2)``.
        This is algebraically identical to ``E_m[sigma_m^2 + mu_m^2] - mu_*^2`` but avoids
        the float cancellation (and possible negative result / NaN std) of that form.

        Args:
            distributions: Member predictive distributions (need ``.mean``/``.variance``).
            beta: Width multiplier for ``upper``/``lower`` (``mean +/- beta * std``).
            name: Name of the returned namedtuple type.
            stddev_coef: Scales the mixture std (variance by its square). Only applied when
                ``std_from_means`` is False.
            std_from_means: Use only the (unbiased) variance of member means.

        Returns:
            namedtuple with fields ``mean std upper lower all_means all_vars max_gap
            upper_max_gap lower_max_gap`` (no ``var`` field, unlike the other ensembles).
        """
        from collections import namedtuple
        stats = namedtuple(name, 'mean std upper lower all_means all_vars max_gap upper_max_gap lower_max_gap')

        all_means = torch.stack([distr.mean for distr in distributions], dim=0)
        mean = all_means.mean(0)

        if std_from_means:
            var = torch.var(all_means, dim=0)
            all_vars = torch.stack([var for _ in distributions], dim=0)
        else:
            all_vars = torch.stack([distr.variance for distr in distributions], dim=0)
            var = all_vars.mean(0) + ((all_means - mean) ** 2).mean(0)
            var = (stddev_coef ** 2) * var
        std = torch.sqrt(var)

        upper = mean + beta * std
        lower = mean - beta * std

        max_gap = torch.max(all_means, dim=0)[0] - torch.min(all_means, dim=0)[0]
        upper_max_gap = mean + max_gap
        lower_max_gap = mean - max_gap

        return stats(mean=mean, std=std, upper=upper, lower=lower,
                     all_means=all_means, all_vars=all_vars,
                     max_gap=max_gap, upper_max_gap=upper_max_gap, lower_max_gap=lower_max_gap)

    def forward_all(self, obs, normalize_input: bool = False) -> List[torchd.Distribution]:
        """Return one predictive distribution per ensemble member."""
        return [
            self.forward_single(obs, vfn_idx, normalize_input=normalize_input)
            for vfn_idx in range(self.num_value_nns)
        ]

    def forward_single(self, obs, vfn_idx, normalize_input=False):
        """Return member ``vfn_idx``'s prediction as ``Independent(Normal, 1)``.

        ``obs`` must be 2-D (asserted). If enabled, the member's own normalizer is
        called with ``update=False``.
        """
        assert len(obs.shape) == 2, f'obs has invalid shape: {obs.shape}'

        if normalize_input:
            normalizer = self.obs_normalizers[vfn_idx]
        else:
            normalizer = lambda obs, update: obs  # pass-through

        vfn = self.vfns[vfn_idx]
        mean_and_var = vfn(normalizer(obs, update=False))
        mean, pre_var = torch.chunk(mean_and_var, 2, dim=1)

        # Small floor keeps the scale strictly positive for torchd.Normal.
        var = self.var_func(pre_var) + 1e-8
        return torchd.Independent(
            torchd.Normal(loc=mean, scale=torch.sqrt(var)), 1
        )

    def forward(self, obs, normalize_input: bool = False):
        """Return the average of the member means."""
        distributions = self.forward_all(obs, normalize_input=normalize_input)
        stacked = torch.stack([distr.mean for distr in distributions], dim=0)
        return stacked.mean(0)

    def forward_stats(self, obs, normalize_input: bool = False):
        """Return ``compute_stats`` using this ensemble's ``beta``, ``stddev_coef`` and ``std_from_means``."""
        values = self.forward_all(obs, normalize_input=normalize_input)
        return self.compute_stats(values, beta=self.beta, stddev_coef=self.stddev_coef, std_from_means=self.std_from_means)


class MaxValueFn(nn.Module):
    """Element-wise maximum over several value functions (f^max).

    Args:
        value_fns: Callables mapping a batch of observations to values. A plain Python
            list is stored as an ordinary attribute, so its members are not registered
            as submodules of this module.
        learner_fns: Stored as ``self.learner_fns``; not read by this class.
        obs_normalizers: Optional list of ``(obs, update=...)`` callables, one per value
            function, applied when ``normalize_input=True``; identity otherwise.
    """

    def __init__(self, value_fns: List[nn.Module],learner_fns=None, obs_normalizers=None) -> None:
        super().__init__()
        self.value_fns = value_fns
        self.learner_fns = learner_fns

        if obs_normalizers is None:
            obs_normalizers = [lambda obs, update: obs for _ in value_fns]
        else:
            assert len(obs_normalizers) == len(value_fns)
        self.obs_normalizers = obs_normalizers

    def forward(self, x, normalize_input=False):
        """Evaluate every value function on ``x`` and return their element-wise max.

        Each output is stacked along a new trailing axis, which is then reduced by
        ``max``; normalizers are called with ``update=False``.
        """
        identity = lambda obs, update: obs
        chosen = self.obs_normalizers if normalize_input else [identity] * len(self.value_fns)

        per_fn = [fn(norm(x, update=False)) for norm, fn in zip(chosen, self.value_fns)]
        return torch.stack(per_fn, dim=-1).max(dim=-1).values



class MaxValueFnPlus(nn.Module):
    """Element-wise maximum over expert value-ensemble bounds and the learner's bound.

    Candidates considered on each ``forward`` call:

    * one tensor per expert in ``value_fns``: the bound chosen by ``expert``
      (``"ucb"`` -> ``upper``, ``"lcb"`` -> ``lower``, ``"mean"`` -> ``mean``). Entries
      whose ensemble gap ``upper - lower`` exceeds ``state_in_distribution`` are
      replaced by ``-999``, so uncertain experts do not win the max;
    * the bound of ``learner_fns`` chosen by ``learner``.

    Experts are skipped with probability ``itr / num_train_steps * explore_decay_rate``
    (clipped to [0, 1] by the comparison with ``np.random.random()``). They are also
    skipped when ``itr >= switch_rl_round``, when ``expert`` is None or ``"NONE"``, or
    when ``value_fns`` is empty. The learner is skipped when ``learner_fns`` or
    ``learner`` is None, or when ``learner == "NONE"``.

    Diagnostics are printed on every call via ``.cpu().numpy()``, which fails on tensors
    that require grad. Callers therefore need to run this without gradient tracking.
    """

    def __init__(self, value_fns: List[nn.Module],learner_fns=None, obs_normalizers=None,expert=None,learner=None, switch_rl_round=999999,state_in_distribution=999999,num_train_steps=100,explore_decay_rate=0) -> None:
        super().__init__()
        self.value_fns = value_fns
        self.learner_fns = learner_fns
        self.expert = expert
        self.learner = learner
        self.switch_rl_round = switch_rl_round
        self.state_in_distribution = state_in_distribution
        self.num_train_steps = num_train_steps
        self.explore_decay_rate = explore_decay_rate

        if obs_normalizers is not None:
            assert len(obs_normalizers) == len(value_fns)
            self.obs_normalizers = obs_normalizers
        else:
            # NOTE: not read by forward(), which always normalizes inside forward_stats.
            self.obs_normalizers = [lambda obs, update: obs for _ in value_fns]

    @staticmethod
    def _select_bound(stats, kind):
        """Return ``stats.upper``/``lower``/``mean`` for ``kind`` ucb/lcb/mean, else None."""
        if kind == "ucb":
            return stats.upper
        if kind == "lcb":
            return stats.lower
        if kind == "mean":
            return stats.mean
        return None

    def _expert_bound(self, value_fn, x):
        """Evaluate one expert ONCE and return its bound with out-of-distribution entries set to -999."""
        stats = value_fn.forward_stats(x, normalize_input=True)
        gap = stats.upper - stats.lower
        print("+EXPERT upper and lower GAP:", gap.cpu().numpy().ravel())

        bound = self._select_bound(stats, self.expert)
        if bound is None:
            print("+self.expert", self.expert)
            print("+self.expert type:", type(self.expert))
            raise ValueError("+MaxValueFnPlus: expert type not supported")

        # Out-of-place: never mutates tensors returned by (and possibly owned by) value_fn.
        bound = bound.masked_fill(gap > self.state_in_distribution, -999)
        print(f"+MaxValueFnPlus:expert:{self.expert}:", bound.cpu().numpy().ravel())
        return bound

    def _learner_bound(self, x):
        """Evaluate the learner ensemble once and return the bound selected by ``self.learner``."""
        if self.learner not in ("ucb", "lcb", "mean"):
            raise ValueError("-MaxValueFnPlus: learner type not supported")
        stats = self.learner_fns.forward_stats(x, normalize_input=True)
        bound = self._select_bound(stats, self.learner)
        print(f"-MaxValueFnPlus:learner:{self.learner}:", bound.cpu().numpy().ravel())
        return bound

    def forward(self, x, normalize_input=False, itr=None):
        """Return the element-wise max over the active candidate bounds.

        Args:
            x: Observation batch passed to every ``forward_stats`` call.
            normalize_input: Ignored; ``forward_stats`` is always called with
                ``normalize_input=True``.
            itr: Current training iteration (required).

        Raises:
            ValueError: if ``itr`` is None, ``expert``/``learner`` names an unsupported
                bound, or no candidate is active.
        """
        if itr is None:
            raise ValueError("itr is None")
        print("   ")
        print("##########MaxValueFnPlus###########")
        print("##EXPERT:", self.expert, ", LEARNER:", self.learner)
        print("##ITR:", itr, ", EXPLORE_DECAY_RATE:", self.explore_decay_rate, ", SWITCHING TO PURE RL ROUND:", self.switch_rl_round)

        candidates = []
        if np.random.random() > itr / self.num_train_steps * self.explore_decay_rate:
            if itr < self.switch_rl_round and self.expert is not None and self.expert != "NONE" and len(self.value_fns) > 0:
                print("+MaxValueFnPlus: using EXPERTS:", self.expert, ", state_in_distribution:", self.state_in_distribution)
                for value_fn in self.value_fns:
                    candidates.append(self._expert_bound(value_fn, x))
            else:
                print("+MaxValueFnPlus not using experts, itr:", itr)

        if self.learner_fns is not None and self.learner is not None and self.learner != "NONE":
            print("-MaxValueFnPlus: using LEARER:", self.learner)
            candidates.append(self._learner_bound(x))
        else:
            print("-MaxValueFnPlus not using learner")

        if len(candidates) == 0:
            raise ValueError("MaxValueFnPlus: error")

        max_obj = torch.max(torch.stack(candidates), 0)

        selections = max_obj.indices.cpu().numpy().ravel()
        # When experts ran, index len(value_fns) is the learner and falls in the last
        # (closed) bin; when experts were skipped the learner is index 0.
        hist = np.histogram(selections, bins=list(range(len(self.value_fns) + 2)))
        print("MaxValueFnPlus EXPERT LEARNER SELECTION IDX(ignore last bin):", hist)

        print("###################################")
        print("   ")
        return max_obj.values



        

# Value-based policy selection.
# NOTE X: select using the LCB of the learner and the UCB of the experts.
class ActivePolicySelectorPlus:
    """Configurable selector between expert value ensembles and the learner.

    Experts are ranked by the bound named in ``expert`` (``"ucb"`` -> ``upper``,
    ``"lcb"`` -> ``lower``, ``"mean"`` -> ``mean``; None/``"NONE"`` disables them). The
    learner's bound named in ``learner`` is compared against the best expert's
    ``upper`` bound. The learner wins when no expert was ranked, or when its bound
    exceeds that upper bound. The learner's index is always ``len(value_fns)``, i.e.
    the slot right after the experts.
    """
    # NOTE: Currently the exact same calculation is done in ActiveStateExplorer, thus this one is not used.

    def __init__(self, value_fns: List[ValueEnsemble], value_learner_fn=None,itr=None,num_train_steps=None,expert=None,learner=None) -> None:
        self.value_fns = value_fns
        self.value_learner_fn = value_learner_fn
        # Stored but not read by this class.
        self.itr = itr
        self.num_train_steps = num_train_steps
        self.expert = expert
        self.learner = learner

    @staticmethod
    def _bound(stats, kind):
        """Return ``stats.upper``/``lower``/``mean`` for ``kind`` ucb/lcb/mean, else None."""
        if kind == "ucb":
            return stats.upper
        if kind == "lcb":
            return stats.lower
        if kind == "mean":
            return stats.mean
        return None

    def _get_best_expert(self, obs):
        """Return ``(index, stats)`` of the chosen value function for ``obs``.

        Raises:
            ValueError: if there are neither experts nor a learner, if ``expert`` or
                ``learner`` names an unsupported bound, or if nothing could be selected.
        """
        if len(self.value_fns) == 0 and self.value_learner_fn is None:
            print("warning! no experts for APS")
            # Raise rather than exit(): exit() ends the process with status 0, hiding the failure.
            raise ValueError("APS plus: no expert value functions and no learner value function")

        best_idx = None
        best_valobj = None

        if len(self.value_fns) > 0 and self.expert is not None and self.expert != "NONE":
            if self.expert not in ("ucb", "lcb", "mean"):
                print("aps expert not defined")
                raise ValueError("aps expert not defined")
            ranked = sorted(
                [(idx, vfn.forward_stats(obs)) for idx, vfn in enumerate(self.value_fns)],
                key=lambda pair: self._bound(pair[1], self.expert),
            )
            # Stable sort: among ties the highest index wins.
            best_idx, best_valobj = ranked[-1]

        if self.value_learner_fn is not None and self.learner is not None and self.learner != "NONE":
            if self.learner not in ("ucb", "lcb", "mean"):
                print("aps learner not defined")
                print(self.learner)
                raise ValueError("aps learner not defined")

            # One forward pass serves both the comparison value and the returned stats.
            learner_state_val = self.value_learner_fn.forward_stats(obs)
            learner_val = self._bound(learner_state_val, self.learner)
            # Same index in every learner return path (was 0 when experts were disabled,
            # which collided with expert 0).
            learner_idx = len(self.value_fns)

            if best_valobj is None:
                print("best_valobj None", -1, learner_val)
                return learner_idx, learner_state_val

            if best_valobj.upper < learner_val:
                print("best_valobj.upper < learner_lcb_val", -1, learner_val)
                return learner_idx, learner_state_val

        if best_idx is None:
            raise ValueError("APS plus: best_idx:No value function found")

        return best_idx, best_valobj

    @torch.no_grad()
    def select(self, obs):
        """Convert a single observation to a batch of one and return ``(index, stats)``."""
        obs = to_torch(obs).unsqueeze(0)
        best_idx, best_valobj = self._get_best_expert(obs)

        return best_idx, best_valobj


def calc_loss_actor(pi: nn.Module, obs: np.ndarray, action: np.ndarray, tgt_val: np.ndarray, log_probs_old: np.ndarray, clip_eps: float = 0.2):
    """Mamba actor loss (importance-weighted policy gradient).

    Computes ``-mean(exp(log pi(action | obs) - log_probs_old) * tgt_val)``. The ratio is
    not clipped, and ``clip_eps`` is unused. Despite the ``np.ndarray`` annotations, the
    inputs flow through torch ops, so tensors are expected.
    """
    log_ratio = pi(obs).log_prob(action) - log_probs_old
    weighted_targets = torch.exp(log_ratio) * tgt_val
    return -weighted_targets.mean()




class MambaAgent(Agent):
    """Mamba learner agent.

    - Its value function ``vfn_aggr`` is the max over expert value functions (f^{max}).
    - ``update()`` only trains the policy ``pi``, using an (importance-weighted) policy
      gradient, or the PPO actor loss from ``.ppo`` when ``use_ppo_loss`` is set.
    """

    def __init__(self, pi: nn.Module, vfn_aggr: nn.Module, optimizer, obs_normalizer: Callable | None = None, max_grad_norm: None | float = None, standardize_advantages: bool = True,
                 gamma: float = 1., lambd: float = 0.9, use_ppo_loss: bool = False) -> None:
        super().__init__(pi, vfn_aggr, optimizer, obs_normalizer)

        self.max_grad_norm = max_grad_norm
        self.standardize_advantages = standardize_advantages

        self.gamma = gamma
        self.lambd = lambd

        # Alias of the value function handed to the base class.
        self.vfn_aggr = self.vfn

        self.use_ppo_loss = use_ppo_loss

    def _actor_loss(self, states, actions, advs, log_probs_old):
        """Dispatch to the PPO actor loss or the module-level Mamba actor loss."""
        if self.use_ppo_loss:
            from .ppo import calc_loss_actor as calc_ppo_loss_actor
            return calc_ppo_loss_actor(self.pi, states, actions, advs, log_probs_old)
        return calc_loss_actor(self.pi, states, actions, advs, log_probs_old)

    def update(self, transitions: Sequence[dict], num_epochs: int = 10, batch_size: int = 128):
        """Run ``num_epochs * max(len(transitions) // batch_size, 1)`` policy updates.

        Advantages are optionally standardized with statistics over all ``transitions``.
        A batch whose loss is None is skipped, but still counts toward ``n_updates``.

        Returns:
            dict mapping ``'loss/actor'`` and ``'loss/entropy'`` to their batch means.
        """
        if self.standardize_advantages:
            adv_std, adv_mean = torch.std_mean(to_torch([tr['adv'] for tr in transitions]), unbiased=False)

        batches = yield_batch_infinitely(transitions, batch_size)
        history = defaultdict(list)
        total_steps = num_epochs * max(len(transitions) // batch_size, 1)

        for _ in range(total_steps):
            batch = next(batches)

            obs = batch['state'].type(torch.float32)
            if self.obs_normalizer:
                obs = self.obs_normalizer(obs, update=False)

            act = batch['action'].type(torch.float32)
            policy = self.pi(obs)

            adv = batch['adv'].type(torch.float32)
            if self.standardize_advantages:
                adv = (adv - adv_mean) / (adv_std + 1e-8)

            old_logp = batch['log_prob'].type(torch.float32)

            self.optimizer.zero_grad()
            loss = self._actor_loss(obs, act, adv, old_logp)
            print('loss', loss)

            if loss is not None:
                loss.backward()
                if self.max_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(self.pi.parameters(), self.max_grad_norm)
                self.optimizer.step()

                history['actor'].append(loss.item())
                history['entropy'].append(torch.mean(policy.entropy()).item())

            self.n_updates += 1

        return {f'loss/{key}': np.asarray(vals).mean() for key, vals in history.items()}



#NOTE??: why does the RPI agent use MaxValueFnPlus (max over expert and learner value functions)?
# from rpi.agents.ppo import calc_loss_critic
from .ppo import PPOAgent2
class RPIAgent(PPOAgent2):
    """RPI learner agent built on ``PPOAgent2``.

    Unlike ``MambaAgent`` (which hands its aggregated value function to its base class),
    this agent passes the learner's own value function ``vfn_learner`` to ``PPOAgent2``.
    ``vfn_aggr_plus`` (max over all value functions, including the learner's) is only
    stored on the instance.

    NOTE(review): ``max_grad_norm``, ``standardize_advantages`` and ``use_ppo_loss`` are
    accepted but neither forwarded to ``PPOAgent2.__init__`` nor stored here. Confirm
    whether that is intended.
    """

    def __init__(self, pi: nn.Module, vfn_learner: nn.Module, vfn_aggr_plus: nn.Module, optimizer,
                 obs_normalizer: Callable | None = None, max_grad_norm: None | float = None,
                 standardize_advantages: bool = True, gamma: float = 1., lambd: float = 0.9, use_ppo_loss: bool = False) -> None:
        super().__init__(pi, vfn_learner, optimizer, obs_normalizer)
        self.gamma, self.lambd = gamma, lambd
        # Max over all vfns including the learner's one.
        self.vfn_aggr_plus = vfn_aggr_plus

    # def update(self, transitions, num_epochs=10, batch_size: int = 128):
    #     from rpi.agents.ppo import calc_loss_critic
    #     if self.standardize_advantages:
    #         advs = [tr['adv'] for tr in transitions]
    #         std_advs, mean_advs = torch.std_mean(to_torch(advs), unbiased=False)

    #     sample_itr = yield_batch_infinitely(transitions, batch_size=batch_size)
    #     loss_dict = defaultdict(list)

    #     num_updates = max(len(transitions) // batch_size, 1) * num_epochs

    #     for epoch in range(num_updates):
    #         batch = next(sample_itr)

    #         states = batch['state'].type(torch.float32)
    #         if self.obs_normalizer:
    #             states = self.obs_normalizer(states, update=False)

    #         actions = batch['action'].type(torch.float32)
    #         distribs = self.pi(states)
    #         # vs_pred = self.critic.value_nn(states)

    #         advs = batch['adv'].type(torch.float32)
    #         if self.standardize_advantages:
    #             advs = (advs - mean_advs) / (std_advs + 1e-8)

    #         log_probs_old = batch['log_prob'].type(torch.float32)
    #         vs_pred_old = batch['v_pred'].type(torch.float32)
    #         vs_teacher = batch['v_learner_teacher'].type(torch.float32)


    #         # Same shape as vs_pred: (batch_size, 1)
    #         vs_pred_old = vs_pred_old[..., None]
    #         vs_teacher = vs_teacher[..., None]

    #         self.optimizer.zero_grad()
    #         loss_actor = calc_loss_actor(self.pi, states, actions, advs, log_probs_old)
    #         if loss_actor is None:
    #             loss_actor = 0

    #         loss_critic = calc_loss_critic(self.vfn, states, vs_teacher)
    #         #NOTE??: why loss_entropy
    #         loss_entropy = -torch.mean(distribs.entropy())
    #         loss = (
    #             loss_actor
    #             + self.coef_critic * loss_critic
    #             + self.coef_entropy * loss_entropy
    #         )
    #         loss.backward()
    #         if self.max_grad_norm is not None:
    #             torch.nn.utils.clip_grad_norm_(
    #                 list(self.pi.parameters()) + list(self.vfn.parameters()), self.max_grad_norm
    #             )
    #         self.optimizer.step()

    #         # Append to loss dict
    #         loss_dict['actor'].append(loss_actor.item())
    #         loss_dict['critic'].append(loss_critic.item())
    #         loss_dict['entropy'].append(loss_entropy.item())
    #         loss_dict['all'].append(loss.item())


    #         self.n_updates += 1

    #     loss_info = {
    #         f'loss/{key}': np.asarray(arr).mean() for key, arr in loss_dict.items()
    #     }

    #     return loss_info




class ActiveStateExplorer:
    """Decide whether to keep running the learner, based on the best expert's uncertainty.

    Args:
        value_fns: Expert value ensembles exposing ``forward_stats``.
        sigma: Uncertainty threshold; below it, the learner keeps exploring.
        uncertainty: Which statistic to threshold, ``"std"`` or ``"max_gap"``.
    """

    def __init__(self, value_fns: List[ValueEnsemble], sigma, uncertainty="std") -> None:
        self.value_fns = value_fns
        self.sigma = sigma
        self._uncertainty = uncertainty

    def _get_best_expert(self, obs):
        """Return ``(index, stats)`` of the expert with the largest ``upper`` bound.

        The sort is stable, so among ties the highest index is returned.
        """
        candidates = list(enumerate(vfn.forward_stats(obs) for vfn in self.value_fns))
        candidates.sort(key=lambda pair: pair[1].upper)
        return candidates[-1]

    @torch.no_grad()
    def should_explore(self, obs):
        """When a (best) expert on current state is sure about what will happen, we will not gain anything by switching to expert,
        thus we should 'explore', by keep running the learner policy.

        Returns:
            ``(explore, best_idx, best_stats, uncertainty_value)``; ``explore`` is
            ``uncertainty_value < sigma``.

        Raises:
            ValueError: if ``uncertainty`` is neither ``"std"`` nor ``"max_gap"``.
        """
        batch = to_torch(obs).unsqueeze(0)
        top_idx, top_stats = self._get_best_expert(batch)

        if self._uncertainty == "std":
            spread = top_stats.std
        elif self._uncertainty == "max_gap":
            spread = top_stats.max_gap
        else:
            raise ValueError(f"Invalid uncertainty type: {self._uncertainty}")

        return spread < self.sigma, top_idx, top_stats, spread
        




class ActivePolicySelector:
    """Choose between expert value ensembles and the learner for one observation.

    The expert with the largest ``upper`` bound is the default choice. The learner
    (index ``len(value_fns)``) is chosen instead when there are no experts, or when its
    ``lower`` bound exceeds that expert's ``upper`` bound.
    """
    # NOTE: Currently the exact same calculation is done in ActiveStateExplorer, thus this one is not used.

    def __init__(self, value_fns: List[ValueEnsemble], value_learner_fn=None,itr=None,num_train_steps=None) -> None:
        self.value_fns = value_fns
        self.value_learner_fn = value_learner_fn
        # Stored but not read by this class.
        self.itr = itr
        self.num_train_steps = num_train_steps

    def _get_best_expert(self, obs):
        """Return ``(index, stats)`` of the chosen value function for ``obs``.

        Raises:
            ValueError: if there are neither experts nor a learner value function.
        """
        if len(self.value_fns) == 0 and self.value_learner_fn is None:
            print("warning! no experts for APS")
            # Raise rather than exit(): exit() ends the process with status 0, hiding the failure.
            raise ValueError("APS: no expert value functions and no learner value function")

        best_idx = None
        best_valobj = None

        if len(self.value_fns) > 0:
            ranked = sorted(
                [(idx, vfn.forward_stats(obs)) for idx, vfn in enumerate(self.value_fns)],
                key=lambda pair: pair[1].upper,
            )
            # Stable sort: among ties the highest index wins.
            best_idx, best_valobj = ranked[-1]

        if self.value_learner_fn is not None:
            # One forward pass serves both the comparison value and the returned stats.
            learner_state_val = self.value_learner_fn.forward_stats(obs)
            learner_lcb_val = learner_state_val.lower

            if best_valobj is None:
                # value_fns is empty here, so 0 is also len(value_fns), the learner slot.
                print("best_valobj None", 0, learner_lcb_val)
                return 0, learner_state_val

            if best_valobj.upper < learner_lcb_val:
                print("best_valobj.upper < learner_lcb_val", 0, learner_lcb_val)
                return len(self.value_fns), learner_state_val

        return best_idx, best_valobj

    @torch.no_grad()
    def select(self, obs):
        """Convert a single observation to a batch of one and return ``(index, stats)``."""
        obs = to_torch(obs).unsqueeze(0)
        best_idx, best_valobj = self._get_best_expert(obs)

        return best_idx, best_valobj
